# Copyright (c) 2010 cycad <cycad@zetasquad.com>. All rights reserved.

from logging import DEBUG,INFO,ERROR,CRITICAL
import socket
import sys
import struct 
import random 
import traceback
import SubspaceEncryption
import time
import math
import select
from SubspaceFileChecksum import FileChecksum


def GetTickCountHs():
	"""Return the current tick count in hundredths of a second, truncated to 32 bits.

	Uses time.monotonic() when the interpreter provides it, otherwise
	time.time(). time.clock() is deliberately avoided: on Unix it reports
	processor time, which barely advances while the process is blocked in
	select(), and it was removed in Python 3.8. Callers compare tick values
	with TickDiff(), which handles the 32-bit wraparound.
	"""
	clock = getattr(time, 'monotonic', time.time)
	return math.trunc(clock() * 100) & 0xFFFFFFFF


def TickDiff(now, base):
	"""Return how many ticks `now` is ahead of `base`, modulo 2**32.

	Because both values are 32-bit counters that wrap, the raw difference is
	masked back into the unsigned 32-bit range.
	"""
	mask = 0xFFFFFFFF
	return mask & (now - base)


# Outbound queue priorities. The values index CoreStack's list of packet
# queues; the high-priority queue is served before the normal one.
PRIORITY_HIGH = 0	# for sending a high priority packet, like position updates or ack messages
PRIORITY_NORMAL = 1	# for sending normal priority packets

MAX_PACKET = 512	# the maximum size of a single datagram in bytes (also used as the recv buffer size)

RELIABLE_RETRANSMIT_INTERVAL = 50	# in hundredths of seconds; unacked reliable messages are re-sent after this long

EVENT_TICK_INTERVAL = 10 # how often EVENT_TICK happens, in hundredths of a second; backed by an accumulator, so missed ticks are caught up
EVENT_INTERNAL_CORE_PERIODIC_INTERVAL = 10 # how often the internal periodic handler runs (hundredths of a second); not backed by an accumulator, so missed periods are not caught up

NO_PACKET_RECEIVED_TIMEOUT_INTERVAL = 1500 # after 15 seconds without receiving data, disconnect


# core event types (the value of CoreEvent.type)
EVENT_GAME_PACKET_RECEIVED = 1	# a non-core packet arrived; the event carries it in its .packet attribute
EVENT_TICK = 2	# periodic tick, generated every EVENT_TICK_INTERVAL
EVENT_DISCONNECT = 3	# the connection ended (local request, server disconnect packet, or receive timeout)
	
class CoreEvent:
	"""An event produced by CoreStack.

	`type` holds one of the EVENT_* constants. Game-packet events also get a
	`packet` attribute attached by the stack after construction.
	"""
	def __init__(self, type):
		# the parameter keeps the name `type` so keyword callers keep working
		event_type = type
		self.type = event_type

class QueuedPacket:
	"""A payload waiting in an outbound queue, plus whether it must be sent reliably.

	Construction fails if the packet, including the reliable header when
	`reliable` is set, would exceed MAX_PACKET bytes.
	"""
	def __init__(self, data, reliable=False):
		self.data = data
		self.reliable = reliable
		wire_size = self.totalPacketSize()
		if wire_size > MAX_PACKET:
			raise Exception("Packet has a size greater than the maximum allowed: %d" % wire_size)

	def totalPacketSize(self):
		"""Return the size on the wire: the payload plus 6 bytes of reliable header when reliable."""
		header_size = 6 if self.reliable else 0
		return len(self.data) + header_size

class CoreStack:
	"""Client side of the Subspace core (transport) protocol over UDP.

	Performs the connection handshake and encrypts/decrypts traffic once the
	server key is known. Provides reliable delivery: acks, in-order delivery
	and retransmission. Clusters small outbound packets, splits oversized
	outbound messages into chunks and reassembles incoming ones, and reports
	activity as CoreEvent objects returned by waitForEvent().

	Typical use: connectToServer(), then call waitForEvent() in a loop and
	queuePacket() to send.
	"""
	def __init__(self, debug=False,logger=None):
		"""Create an unconnected stack.

		debug  -- when true, log every packet sent and received at DEBUG level.
		logger -- optional object with a log(level, message) method (e.g. a
		          logging.Logger); when None, messages are printed instead.
		"""
		self.__debug = debug
		self.logger = logger

		# per-connection transport state (socket, crypto, queues, reliable and chunk buffers)
		self.__socket = None
		self.__resetTransportState()
		self.__event_list = [] # the list of pending CoreEvent objects

		# waitForEvent() only loops while connected; connectToServer() sets this
		self.__connected = False
		self.__last_packet_received_tick = GetTickCountHs()

		# core packet handlers, keyed by the second byte of a 0x00-prefixed packet
		self.__packet_handlers = {
			0x03 : self.__handleReliableMessage,
			0x04 : self.__handleAckPacket,
			0x05 : self.__handleSyncRequestPacket,
			0x06 : self.__handleSyncResponsePacket,
			0x07 : self.__handleDisconnectPacket,
			0x08 : self.__handleChunk,
			0x09 : self.__handleChunkEnd,
			0x0A : self.__handleHugeChunk,
			0x0D : self.__handleDisconnectPacket,	#xxx we could parse the reason out of this
			0x0E : self.__handleClusterPacket,
		}

		# for the external loop; set to False by disconnectFromServer() (ie !stopbot)
		# so the caller knows not to reconnect
		self.reconnect = True

	def __resetTransportState(self):
		"""Close the socket (if any) and discard all per-connection protocol state."""
		if self.__socket is not None:
			self.__socket.close()
		self.__socket = None
		self.__crypto = None # initialized after the server's key is received; once set, traffic is encrypted
		self.__packet_queues = [[], []]	# one list of QueuedPacket objects per PRIORITY_* value
		self.__reliable_messages_in_transit = {} # ack_id : [packet, last transmit tick]; a list so the tick can be updated in place while the dict is iterated
		self.__incoming_reliable_packets = {} # ack_id : payload of reliable messages received ahead of order (reliable header stripped)
		# partially received chunk (0x08/0x09) and huge chunk (0x0A) messages
		self.__chunk_list = None
		self.__huge_chunk_data = None

	def resetState(self):
		"""Prepare for a reconnect.

		Closes the socket and drops all per-connection state: queues,
		in-transit and out-of-order reliable messages, partial chunks and
		pending events.
		"""
		self.__resetTransportState()
		self.__event_list = [] # the list of pending CoreEvent objects

	def __log(self,level,message):
		"""Log through the configured logger, or print when there is none."""
		if self.logger:
			self.logger.log(level,message)
		else:
			print (message)

	def __addPendingEvent(self, core_event):
		"""Adds CoreEvent to the list for handling."""
		self.__event_list.append(core_event)

	def connectToServer(self, server, port,newconn=1):
		"""Connect to the server and perform the core handshake.

		Sends the client key (0x00 0x01) and answers a server challenge
		(0x00 0x05) with 0x00 0x06 if one arrives. The handshake completes
		when the server key (0x00 0x02) is received; encryption is then set
		up and the stack is marked connected.

		If no valid reply arrives after repeated read timeouts, a CRITICAL
		message is logged and the stack stays disconnected; no exception is
		raised in that case. Errors raised by the socket calls themselves are
		not caught here. `newconn` is accepted for compatibility and unused.
		"""
		# start clean: a socket, crypto state or queued packets left over from a previous
		# connection would otherwise leak into (and corrupt) this handshake
		self.__resetTransportState()

		self.__total_packets_sent = 0
		self.__total_packets_received = 0
		self.__next_outgoing_ack_id = 0
		self.__next_incoming_ack_id = 0
		self.__last_sync_response_received_tick = GetTickCountHs()
		self.__last_packet_received_tick = GetTickCountHs()

		self.__event_tick_tick_accumulator = 0
		self.__last_wait_for_event_call_tick = GetTickCountHs()
		self.__last_core_periodic_event_tick = GetTickCountHs()

		self.__server = str(server)
		self.__port = int(port)
		self.__socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
		# use the normalized values so e.g. a port given as a string also works
		self.__socket.connect((self.__server, self.__port))
		self.__socket.setblocking(False) #changed for jython support doesnt seem to affect python
		self.__timeout_interval = 0.01

		STATE_NONE = 0
		STATE_CHALLENGE_RECEIVED = 1
		STATE_CONNECTED = 2

		# Random negative key that fits in a signed 32-bit int. It is sent as 32 bits and
		# also passed to SubspaceEncryption, so keeping it in range makes both agree.
		# (sys.maxint is 2**63-1 on 64-bit builds and does not exist in Python 3.)
		client_key = -random.randrange(1, 0x7FFFFFFF)
		server_challenge = 0

		self.fc = FileChecksum()
		self.fc.generateChecksumArray(0)

		state = STATE_NONE
		timeouts = 0
		while state != STATE_CONNECTED:
			if timeouts > 3:
				break

			# send queueing the packets directly is done here because
			# handshake packets are special
			if state == STATE_NONE:
				# send handshake initiation packet
				self.queuePacket(struct.pack("<BBIH", 0x00, 0x01, client_key & 0xFFFFFFFF, 0x0001))
			elif state == STATE_CHALLENGE_RECEIVED:
				# respond with server challenge
				self.queuePacket(struct.pack("<BBII", 0x00, 0x06, server_challenge & 0xFFFFFFFF, GetTickCountHs()))

			# write sent packets to the network
			self.flushOutboundQueues()

			# read raw packet is used here because the packets during handshake
			# are a little bit special (such as the 0x02 type containing key data)
			packet = self.__readRawPacket(2)
			if packet is None:
				timeouts += 1
				continue

			try:
				game_type, core_type = struct.unpack_from("<BB", packet)

				# default to bad state, unless we get an expected packet
				state = STATE_NONE
				if game_type == 0x00:
					if core_type == 0x05:	# challenge in response to the client key
						server_challenge, = struct.unpack_from("<i", packet, 2)
						state = STATE_CHALLENGE_RECEIVED
					elif core_type == 0x02:
						# success
						server_key, = struct.unpack_from("<i", packet, 2)
						self.__crypto = SubspaceEncryption.SubspaceEncryption(client_key, server_key)
						state = STATE_CONNECTED
						break
			except struct.error:
				# malformed handshake reply: stay in STATE_NONE and resend the client key
				pass

		self.__connected = (state == STATE_CONNECTED)
		if not self.__connected:
			self.__log(CRITICAL,"Unable to connect to server")

	def queuePacket(self, data, reliable=False, priority=PRIORITY_NORMAL):
		"""Queue a packet to be sent by the next flushOutboundQueues().

		data     -- raw packet bytes.
		reliable -- send with a reliable (0x03) header and retransmit until acked.
		priority -- PRIORITY_HIGH or PRIORITY_NORMAL.

		A packet that fits in MAX_PACKET (counting the 6-byte reliable header
		when reliable) is queued as-is. Larger payloads are split into
		reliable chunk messages: 0x08 pieces ending with a 0x09 piece for
		payloads under MAX_PACKET*100 bytes, otherwise 0x0A "huge chunk"
		pieces that each carry the total payload size.
		"""
		size = len(data)
		packet_size = size + (6 if reliable else 0)
		queue = self.__packet_queues[priority]
		# <= matches the limit QueuedPacket enforces. With <, a retransmitted max-size reliable
		# packet (re-queued unreliably by __corePeriodicEvent) would be re-chunked.
		if packet_size <= MAX_PACKET:
			queue.append(QueuedPacket(data, reliable))
		elif size < MAX_PACKET * 100:
			max_chunk = MAX_PACKET - 8 # 6 for reliable header + 2 for packet header
			while data:
				# take exactly max_chunk bytes; slicing [0:chunk_size-1] used to drop a byte per chunk
				chunk, data = data[:max_chunk], data[max_chunk:]
				# the final piece is tagged 0x09 (chunk end), the others 0x08
				header = struct.pack("<BB", 0x00, 0x09 if not data else 0x08)
				queue.append(QueuedPacket(header + chunk, True))
		else:
			max_chunk = MAX_PACKET - 12 # 6 for reliable header + 6 for packet header
			if self.__debug:
				self.__log(DEBUG, "huge chunk total size:" + str(size))
			header = struct.pack("<BBI", 0x00, 0x0A, size)
			while data:
				chunk, data = data[:max_chunk], data[max_chunk:]
				queue.append(QueuedPacket(header + chunk, True))

	def __wrapReliable(self, payload):
		"""Prepend a reliable (0x03) header carrying the next outgoing ack id.

		The result is recorded as in transit so __corePeriodicEvent() retransmits
		it until an ack arrives. Returns the wrapped packet.
		"""
		ack_id = self.__next_outgoing_ack_id
		packet = struct.pack("<BBI", 0x00, 0x03, ack_id & 0xFFFFFFFF) + payload
		self.__reliable_messages_in_transit[ack_id] = [packet, GetTickCountHs()]
		self.__next_outgoing_ack_id += 1
		return packet

	def __buildClusterPacket(self, h, n):
		"""Pack as many queued packets as fit into one 0x0E cluster packet.

		Packets that were included are removed from the queues. A reliable
		packet that does not fit stops later reliable packets from being
		clustered, because reliable packets must go out in ack-id order.
		"""
		data = struct.pack("<BB", 0x00, 0x0E)
		reliable_allowed = True
		for packet_list in (h, n):
			for i, p in enumerate(packet_list):
				size = p.totalPacketSize()
				# each clustered packet is preceded by a one-byte length, hence <= 255 and +1
				if size <= 255 and len(data) + size + 1 <= MAX_PACKET:
					if not p.reliable:
						data += struct.pack('<B', size) + p.data
						packet_list[i] = None
					elif reliable_allowed:
						data += struct.pack('<B', size) + self.__wrapReliable(p.data)
						packet_list[i] = None
				elif p.reliable:
					# this packet wont fit and it is reliable, so dont allow following
					# reliable packets to be sent, because these need to be sent in order
					reliable_allowed = False

		# drop the entries that went into the cluster
		self.__packet_queues[PRIORITY_HIGH] = [x for x in h if x is not None]
		self.__packet_queues[PRIORITY_NORMAL] = [x for x in n if x is not None]
		return data

	def __generateNextOutboundPacket(self):
		"""Extract packets from queues and return a buffer containing the next outbound packet
		that is to be sent on the network, or None if both queues are empty.
		"""
		h = self.__packet_queues[PRIORITY_HIGH]
		n = self.__packet_queues[PRIORITY_NORMAL]

		if not h and not n:
			return None

		# clustering is used when the first two queued packets (high queue first) are small
		# enough to carry a one-byte length prefix
		leading = (h[:2] + n[:2])[:2]
		if len(leading) == 2 and all(p.totalPacketSize() <= 255 for p in leading):
			return self.__buildClusterPacket(h, n)

		# a single, non-cluster packet is being sent
		p = h.pop(0) if h else n.pop(0)
		if p.reliable:
			return self.__wrapReliable(p.data)
		return p.data

	def flushOutboundQueues(self):
		"""Send queued packets until the queues are empty or the socket is not writable."""
		while self.__packet_queues[PRIORITY_HIGH] or self.__packet_queues[PRIORITY_NORMAL]:
			# check to make sure outbound socket is writable
			rlist, wlist, xlist = select.select([], [self.__socket], [], 0)
			if not wlist:
				break

			packet = self.__generateNextOutboundPacket()
			if packet:
				self.__sendRawPacket(packet)

	def __queueAckPacket(self, ack_id):
		"""Queue an ack (0x04) for a received reliable message.

		Acks use PRIORITY_HIGH (as the PRIORITY_HIGH comment intends) so they
		are not delayed behind bulk data, which would trigger needless
		retransmissions by the peer.
		"""
		self.queuePacket(struct.pack("<BBI", 0x00, 0x04, ack_id & 0xFFFFFFFF), priority=PRIORITY_HIGH)

	def __queueDisconnectPacket(self):
		"""Queue a disconnect packet."""
		self.queuePacket(struct.pack("<BB", 0x00, 0x07))

	def disconnectFromServer(self):
		"""Queue a disconnect packet, raise EVENT_DISCONNECT and disable reconnecting.

		The packet is flushed and the stack marked disconnected on the next
		waitForEvent() call.
		"""
		self.__queueDisconnectPacket()
		self.__addPendingEvent(CoreEvent(EVENT_DISCONNECT))
		self.reconnect = False

	def shouldReconnect(self):
		"""Return False once disconnectFromServer() has been called."""
		return self.reconnect

	def __handleReliableMessage(self, packet):
		"""Ack an incoming reliable (0x03) message and deliver it in ack-id order.

		Messages that arrive early are buffered until the gap before them is
		filled. Duplicates of already-delivered messages are acked again but
		not re-processed.
		"""
		zero, core_type, ack_id = struct.unpack_from("<BBI", packet)

		self.__queueAckPacket(ack_id)

		if ack_id == self.__next_incoming_ack_id:
			self.__next_incoming_ack_id += 1
			self.__processIncomingPacket(packet[6:])

			# deliver buffered messages that are now in order; re-check the dict each time
			# because processing may re-enter this method
			while self.__next_incoming_ack_id in self.__incoming_reliable_packets:
				payload = self.__incoming_reliable_packets.pop(self.__next_incoming_ack_id)
				self.__next_incoming_ack_id += 1
				self.__processIncomingPacket(payload)

		elif ack_id > self.__next_incoming_ack_id:
			# early message; a retransmitted duplicate just replaces the buffered copy
			self.__incoming_reliable_packets[ack_id] = packet[6:]

	def __handleChunk(self, packet):
		"""Buffer the body of a 0x08 chunk piece."""
		if self.__chunk_list is None:
			self.__chunk_list = []
		self.__chunk_list.append(packet[2:])

	def __handleChunkEnd(self, packet):
		"""Append the final 0x09 piece and process the reassembled message.

		A chunk end with no preceding chunk pieces is ignored.
		"""
		if self.__chunk_list:
			self.__chunk_list.append(packet[2:])
			self.__processIncomingPacket(''.join(self.__chunk_list))
			self.__chunk_list = None

	#added by junky
	def __handleHugeChunk(self,packet):
		"""Accumulate 0x0A huge-chunk pieces; process the payload once total_size bytes arrive."""
		zero, core_type, total_size = struct.unpack_from("<BBI",packet)
		piece = packet[6:]
		if self.__huge_chunk_data is None: #new chunk
			self.__huge_chunk_data = piece
			if self.__debug:
				self.__log(DEBUG, "huge chunk type: " + piece.encode('hex'))
		else:
			self.__huge_chunk_data += piece

		received = len(self.__huge_chunk_data)
		if received == total_size: #packet complete
			self.__processIncomingPacket(self.__huge_chunk_data)
			self.__huge_chunk_data = None
		elif received > total_size:
			# more data than announced: the stream is inconsistent, so drop it instead of
			# buffering forever (== would never match again)
			self.__log(ERROR, "Huge chunk overflow (%d > %d bytes), discarding" % (received, total_size))
			self.__huge_chunk_data = None

	def waitForEvent(self):
		"""Wait for an event to occur.  If no event is immediately available this
		call blocks.  Must be called frequently for good performance.

		Returns the next CoreEvent, or None once the stack is not connected.
		Returning an EVENT_DISCONNECT event marks the stack as disconnected.
		"""
		while self.__connected:
			self.__processIncomingPackets()
			self.flushOutboundQueues()

			# process tick event if necessary
			now = GetTickCountHs()
			# TickDiff keeps this correct across the 32-bit wrap of the tick counter
			elapsed = TickDiff(now, self.__last_wait_for_event_call_tick)
			if elapsed > 0x7FFFFFFF:
				# "negative" elapsed time (clock stepped backwards); don't treat it as ~497 days
				elapsed = 0
			self.__event_tick_tick_accumulator += elapsed
			self.__last_wait_for_event_call_tick = now
			while self.__event_tick_tick_accumulator > EVENT_TICK_INTERVAL:
				self.__addPendingEvent(CoreEvent(EVENT_TICK))
				self.__event_tick_tick_accumulator -= EVENT_TICK_INTERVAL

			# create internal periodic core event
			if TickDiff(now,  self.__last_core_periodic_event_tick) > EVENT_INTERNAL_CORE_PERIODIC_INTERVAL:
				self.__corePeriodicEvent()
				self.__last_core_periodic_event_tick = now

			# process pending events first
			if self.__event_list:
				# preprocess the event
				event = self.__event_list.pop(0)
				if event.type == EVENT_DISCONNECT:
					self.__connected = False
				return event

			# nothing left to do, wait on an event until the next tick is due
			timeout = float(EVENT_TICK_INTERVAL - self.__event_tick_tick_accumulator) / 100
			# if there are no packets to be sent, just wait for a read
			# otherwise wait for a read or the outbound socket to be writable
			if not self.__packet_queues[PRIORITY_HIGH] and not self.__packet_queues[PRIORITY_NORMAL]:
				select.select([self.__socket], [], [], timeout)
			else:
				select.select([self.__socket], [self.__socket], [], timeout)

	def __processIncomingPackets(self):
		"""Read and process packets until none is immediately available."""
		while True:
			packet = self.__readRawPacket(0)
			if packet is None: break
			self.__processIncomingPacket(packet)

	def __corePeriodicEvent(self):
		"""Internal housekeeping, run from waitForEvent() every
		EVENT_INTERNAL_CORE_PERIODIC_INTERVAL.

		Retransmits unacked reliable messages and raises EVENT_DISCONNECT when
		nothing has been received for NO_PACKET_RECEIVED_TIMEOUT_INTERVAL.
		"""
		now = GetTickCountHs()
		for ack_id, entry in self.__reliable_messages_in_transit.items():
			if TickDiff(now, entry[1]) > RELIABLE_RETRANSMIT_INTERVAL:
				# entry[0] already carries the reliable header, so it is re-sent as-is
				self.queuePacket(entry[0], False, PRIORITY_HIGH)
				entry[1] = now

		if TickDiff(now, self.__last_packet_received_tick) >= NO_PACKET_RECEIVED_TIMEOUT_INTERVAL:
			self.__addPendingEvent(CoreEvent(EVENT_DISCONNECT))

	def __handleClusterPacket(self, packet):
		"""Split a 0x0E cluster into its length-prefixed packets and process each one."""
		packet = packet[2:]
		while packet:
			data_len, = struct.unpack_from("<B", packet)
			if data_len == 0 or data_len + 1 > len(packet):
				# a zero or overrunning length means the rest cannot be framed reliably
				self.__log(ERROR, "Malformed cluster packet, dropping remainder: " + packet.encode('hex'))
				break
			self.__processIncomingPacket(packet[1:data_len + 1])
			packet = packet[data_len + 1:]

	def __handleDisconnectPacket(self, packet):
		"""Raise EVENT_DISCONNECT for a server disconnect (0x07 / 0x0D)."""
		self.__addPendingEvent(CoreEvent(EVENT_DISCONNECT))

	def __handleSyncRequestPacket(self, packet):
		"""Answer a sync request (0x05) with a sync response."""
		self.__queueSyncResponse()

	def __handleSyncResponsePacket(self, packet):
		"""Record when a sync response (0x06) arrived; its contents are not used."""
		self.__last_sync_response_received_tick = GetTickCountHs()

	def __handleAckPacket(self, packet):
		"""Stop retransmitting the reliable message acknowledged by an 0x04 packet."""
		zero, core_type, ack_id = struct.unpack_from("<BBI", packet)
		self.__reliable_messages_in_transit.pop(ack_id, None)

	def __processIncomingPacket(self, packet):
		"""Dispatch one decrypted packet.

		0x00-prefixed core packets go to their handler; anything else becomes
		an EVENT_GAME_PACKET_RECEIVED event carrying the packet. Packets that
		fail to parse (struct.error / IndexError) are logged rather than raised.
		"""
		try:
			packet_type, = struct.unpack_from("<B", packet)
			if packet_type == 0x00:
				# core packet handlers
				core_type, = struct.unpack_from("<B", packet, 1)
				if self.__debug:
					self.__log(DEBUG, "Handling Core Type: 0x%02X" % core_type)

				handler = self.__packet_handlers.get(core_type)
				if handler:
					handler(packet)
				else:
					self.__log(INFO,"wtf corestack type %i not handled"%(core_type,))
			else:
				if self.__debug:
					self.__log(DEBUG, "Handling Game Type: 0x%02X" % packet_type)
				event = CoreEvent(EVENT_GAME_PACKET_RECEIVED)
				event.packet = packet
				self.__addPendingEvent(event)

		except (IndexError, struct.error):
			self.__log(CRITICAL, "Error in packet processing")
			self.__log(CRITICAL, "Packet data: " + packet.encode('hex'))
			for line in traceback.format_exc().splitlines():
				self.__log(DEBUG,line)

	def __readRawPacket(self, timeout):
		"""Wait up to `timeout` seconds for a datagram and return it.

		The packet is decrypted once encryption is active. Returns None if
		nothing was readable, or if the packet could not be parsed for
		decryption.
		"""
		rlist, wlist, xlist = select.select([self.__socket], [], [], timeout)
		if not rlist:
			# short back-off when nothing was readable (kept from the original; it also
			# delays callers that poll with timeout 0)
			time.sleep(0.01)
			return None

		packet = rlist[0].recv(MAX_PACKET)

		self.__total_packets_received += 1

		# decrypt the packet; core packets keep their 2-byte header in the clear, others 1 byte
		if self.__crypto:
			try:
				packet_type, = struct.unpack_from("<B", packet)
				begin_offset = 2 if packet_type == 0x00 else 1
				packet = packet[:begin_offset] + self.__crypto.decryptData(packet[begin_offset:])
			except (IndexError, struct.error):
				packet = None

		if packet:
			if self.__debug:
				self.__log(DEBUG, "Read:" + packet.encode('hex'))

			self.__last_packet_received_tick = GetTickCountHs()

		return packet

	def _queueSyncRequest(self):
		"""Queue a sync request (0x05) with the local tick and packet counters."""
		self.queuePacket(struct.pack("<BBIII", 0x00, 0x05, GetTickCountHs(), self.__total_packets_sent & 0xFFFFFFFF, self.__total_packets_received & 0xFFFFFFFF), priority=PRIORITY_HIGH)

	def __queueSyncResponse(self):
		"""Queue a sync response (0x06)."""
		self.queuePacket(struct.pack("<BBII", 0x00, 0x06, GetTickCountHs(), TickDiff(GetTickCountHs(), self.__last_sync_response_received_tick)), priority=PRIORITY_HIGH)

	def __queueAckMessage(self, ack_id):
		"""Queue an ack for `ack_id`.

		This previously sent a bare 0x03 (reliable) header, which is not an
		ack; it now delegates to the 0x04 ack sender.
		"""
		self.__queueAckPacket(ack_id)

	def __sendRawPacket(self, packet):
		"""Encrypt (once keys are set up) and send one packet on the socket."""
		if self.__debug:
			self.__log(DEBUG, "Sent:" + packet.encode('hex'))

		if self.__crypto:
			# core packets keep their 2-byte header in the clear, game packets 1 byte
			begin_offset = 2 if ord(packet[0]) == 0x00 else 1
			packet = packet[:begin_offset] + self.__crypto.encryptData(packet[begin_offset:])

		self.__socket.sendall(packet)
		self.__total_packets_sent += 1
